"""
PAWS-X: A Cross-lingual Adversarial Dataset for Paraphrase Identification
https://arxiv.org/abs/1908.11828

The dataset consists of 23,659 human translated PAWS evaluation pairs and
296,406 machine translated training pairs in 6 typologically distinct languages.

Examples are adapted from  PAWS-Wiki

Prompt format (same as in mGPT):

"<s>" + sentence1 + ", right? " + mask + ", " + sentence2 + "</s>",

where mask is the string that matches the label:

Yes, No.

Example:

<s> The Tabaci River is a tributary of the River Leurda in Romania, right? No, The Leurda River is a tributary of the River Tabaci in Romania.</s>

Language specific prompts are translated word-by-word with Google Translate
and may differ from the ones used by mGPT and XGLM (they do not provide their prompts).

Homepage: https://github.com/google-research-datasets/paws/tree/master/pawsx
"""
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
from lm_eval import utils

_CITATION = """
@inproceedings{yang-etal-2019-paws,
    title = "{PAWS}-{X}: A Cross-lingual Adversarial Dataset for Paraphrase Identification",
    author = "Yang, Yinfei  and
      Zhang, Yuan  and
      Tar, Chris  and
      Baldridge, Jason",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)",
    month = nov,
    year = "2019",
    address = "Hong Kong, China",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/D19-1382",
    doi = "10.18653/v1/D19-1382",
    pages = "3687--3692",
}"""


class PAWSXBase(Task):
    VERSION = 0
    DATASET_PATH = "paws-x"
    DATASET_NAME = None  # 'en'

    YES = None  # 'Yes'
    NO = None  # 'No'
    QUESTION_WORD = None  # 'right'

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def training_docs(self):
        return self.dataset["train"]

    def validation_docs(self):
        return self.dataset["validation"]

    def test_docs(self):
        return self.dataset["test"]

    def doc_to_text(self, doc):
        # same as in mGPT paper
        return (
            doc["sentence1"]
            + ", "
            + self.QUESTION_WORD
            + "? [MASK], "
            + doc["sentence2"]
        )

    def doc_to_target(self, doc):
        return " " + [self.YES, self.NO][doc["label"]]

    def doc_to_fewshot_prompt(self, doc):

        prompt = self.doc_to_text(doc)
        return prompt.replace("[MASK]", self.doc_to_target(doc)[1:])

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or
            test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """

        ll_yes = rf.loglikelihood_rolling(ctx.replace("[MASK]", self.YES))
        ll_no = rf.loglikelihood_rolling(ctx.replace("[MASK]", self.NO))

        return ll_yes, ll_no

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        ll_yes, ll_no = results

        pred = ll_yes > ll_no

        true_label = doc["label"]

        return {
            "acc": pred == true_label,
        }

    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        return {
            "acc": mean,
        }

    def higher_is_better(self):
        return {"acc": True}

    @utils.positional_deprecated
    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
            The document as returned from training_docs, validation_docs, or test_docs.
        :param num_fewshot: int
            The number of fewshot examples to provide in the returned context string.
        :param provide_description: bool
            Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
        :param rnd: random.Random
            The pseudo-random number generator used to randomly sample examples.
            WARNING: This is currently a required arg although it's optionalized with a default `None`.
        :param description: str
            The task's description that will be prepended to the fewshot examples.
        :returns: str
            The fewshot context.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        description = description + "\n\n" if description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
                        self.validation_docs()
                        if self.has_validation_docs()
                        else self.test_docs()
                    )

                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = (
                "\n\n".join(
                    [
                        # self.doc_to_text(doc) + self.doc_to_target(doc)
                        self.doc_to_fewshot_prompt(doc)
                        for doc in fewshotex
                    ]
                )
                + "\n\n"
            )

        example = self.doc_to_text(doc)
        return description + labeled_examples + example


class PAWSX_en(PAWSXBase):
    DATASET_NAME = "en"
    YES = "Yes"
    NO = "No"
    QUESTION_WORD = "right"


class PAWSX_de(PAWSXBase):
    DATASET_NAME = "de"
    YES = "Ja"
    NO = "Nein"
    QUESTION_WORD = "richtig"


class PAWSX_fr(PAWSXBase):
    DATASET_NAME = "fr"
    YES = "Oui"
    NO = "No"
    QUESTION_WORD = "right"


class PAWSX_es(PAWSXBase):
    DATASET_NAME = "es"
    YES = "Sí"
    NO = "No"
    QUESTION_WORD = "verdad"


class PAWSX_ja(PAWSXBase):
    DATASET_NAME = "ja"
    YES = "はい"
    NO = "いいえ"
    QUESTION_WORD = "ですね"


class PAWSX_ko(PAWSXBase):
    DATASET_NAME = "ko"
    YES = "예"
    NO = "아니요"
    QUESTION_WORD = "맞죠"


class PAWSX_zh(PAWSXBase):
    DATASET_NAME = "zh"
    YES = "是"
    NO = "不是"
    QUESTION_WORD = "对吧"


LANGS = [
    "en",
    "de",
    "es",
    "fr",
    "ja",
    "ko",
    "zh",
]

LANG_CLASSES = [
    PAWSX_en,
    PAWSX_de,
    PAWSX_es,
    PAWSX_fr,
    PAWSX_ja,
    PAWSX_ko,
    PAWSX_zh,
]


def construct_tasks():
    tasks = {}
    for lang, lang_class in zip(LANGS, LANG_CLASSES):
        tasks[f"pawsx_{lang}"] = lang_class
    return tasks
